yerang committed on
Commit
3995cc7
1 Parent(s): 8603f30

Create gradio_pipeline_stf.py

Files changed (1)
  1. src/gradio_pipeline_stf.py +132 -0
src/gradio_pipeline_stf.py ADDED
@@ -0,0 +1,132 @@
+ # coding: utf-8
+
+ """
+ Pipeline for gradio
+ """
+ import gradio as gr
+ from .config.argument_config import ArgumentConfig
+ from .live_portrait_pipeline import LivePortraitPipeline
+ from .utils.io import load_img_online
+ from .utils.rprint import rlog as log
+ from .utils.crop import prepare_paste_back, paste_back
+ # from .utils.camera import get_rotation_matrix
+
+ from .utils.video import merge_audio_video
+
+
+ def update_args(args, user_args):
+     """update the args according to user inputs
+     """
+     for k, v in user_args.items():
+         if hasattr(args, k):
+             setattr(args, k, v)
+     return args
+
+
+ class GradioPipeline(LivePortraitPipeline):
+
+     def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
+         super().__init__(inference_cfg, crop_cfg)
+         # self.live_portrait_wrapper = self.live_portrait_wrapper
+         self.args = args
+
+     def execute_video(
+         self,
+         input_image_path,
+         input_video_path,
+         flag_relative_input,
+         flag_do_crop_input,
+         flag_remap_input,
+         audio_path=None,
+     ):
+         """for video-driven portrait animation
+         """
+         if input_image_path is not None and input_video_path is not None:
+             args_user = {
+                 'source_image': input_image_path,
+                 'driving_info': input_video_path,
+                 'flag_relative': flag_relative_input,
+                 'flag_do_crop': flag_do_crop_input,
+                 'flag_pasteback': flag_remap_input,
+             }
+             # update config from user input
+             self.args = update_args(self.args, args_user)
+             self.live_portrait_wrapper.update_config(self.args.__dict__)
+             self.cropper.update_config(self.args.__dict__)
+             # video-driven animation
+             video_path, video_path_concat = self.execute(self.args)
+             # gr.Info("Run successfully!", duration=2)
+             # return video_path, video_path_concat
+
+             # optionally mux the provided audio track into the rendered video
+             if audio_path is not None:
+                 merged_video_path = video_path[:-4] + "_audio.mp4"  # strip the ".mp4" extension before adding the suffix
+                 merge_audio_video(video_path, audio_path, merged_video_path)
+                 video_path = merged_video_path
+
+             gr.Info("Run successfully!", duration=2)
+             if video_path.endswith(".jpg"):
+                 # return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
+                 return video_path, video_path_concat
+             else:
+                 # return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+                 return video_path, video_path_concat
+         else:
+             raise gr.Error("Please upload the source portrait and driving video 🤗🤗🤗", duration=5)
+
+     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
+         """for single image retargeting
+         """
+         # disposable feature
+         f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
+             self.prepare_retargeting(input_image, flag_do_crop)
+
+         if input_eye_ratio is None or input_lip_ratio is None:
+             raise gr.Error("Invalid ratio input 💥!", duration=5)
+         else:
+             x_s_user = x_s_user.to("cuda")
+             f_s_user = f_s_user.to("cuda")
+             # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
+             combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
+             eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
+             # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
+             combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
+             lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+             num_kp = x_s_user.shape[1]
+             # default: use x_s
+             x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+             # D(W(f_s; x_s, x′_d))
+             out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
+             out = self.live_portrait_wrapper.parse_output(out['out'])[0]
+             out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
+             # gr.Info("Run successfully!", duration=2)
+             return out, out_to_ori_blend
+
+     def prepare_retargeting(self, input_image, flag_do_crop=True):
+         """for single image retargeting
+         """
+         if input_image is not None:
+             # gr.Info("Upload successfully!", duration=2)
+             inference_cfg = self.live_portrait_wrapper.cfg
+             ######## process source portrait ########
+             img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=1)  # n=1 means do not trim the pixels
+             log(f"Load source image from {input_image}.")
+             crop_info = self.cropper.crop_single_image(img_rgb)
+             if flag_do_crop:
+                 I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+             else:
+                 I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+             x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
+             # R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+             ############################################
+             f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+             x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+             source_lmk_user = crop_info['lmk_crop']
+             crop_M_c2o = crop_info['M_c2o']
+             mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+             return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
+         else:
+             # when the clear button is pressed, go here
+             raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)